#include "line2Dup.h"
#include <memory>
#include <iostream>
#include <assert.h>
#include <chrono>
using namespace std;
using namespace cv;


// Simple stopwatch. Timing starts on construction; reset() restarts it and
// out() prints the elapsed time and then restarts it.
class Timer
{
public:
    Timer() : beg_(clock_::now()) { }

    // Restart the measurement from now.
    void reset() { beg_ = clock_::now(); }

    // Seconds elapsed since construction or the last reset()/out().
    double elapsed() const
    {
        return std::chrono::duration_cast<second_>(clock_::now() - beg_).count();
    }

    // Print `message`, then the elapsed time in seconds, then restart timing.
    void out(const std::string& message = "")
    {
        const double t = elapsed();
        std::cout << message << "\nelapsed time:" << t << "s" << std::endl;
        reset();
    }

private:
    // steady_clock is monotonic: measured intervals never go negative or jump
    // when the system clock is adjusted. high_resolution_clock may be an alias
    // of system_clock on some standard libraries, which gives no such guarantee.
    typedef std::chrono::steady_clock clock_;
    typedef std::chrono::duration<double, std::ratio<1> > second_;
    std::chrono::time_point<clock_> beg_;
};

// Non-maximum suppression (NMS), copied from cv::dnn so that opencv_contrib
// is not required. Support code only — feel free to collapse it.
namespace  cv_dnn
{

namespace
{

template <typename T>
static inline bool SortScorePairDescend(
                          const std::pair<float, T>& pair1,
                          const std::pair<float, T>& pair2 )
{
    return pair1.first > pair2.first;
}

} // namespace

inline void GetMaxScoreIndex(
	const std::vector<float>& scores, 
	const float threshold, 
	const int top_k,
    std::vector<std::pair<float, int> >& score_index_vec )
{
    for (size_t i = 0; i < scores.size(); ++i)
    {
        if (scores[i] > threshold)
        {
            score_index_vec.push_back(std::make_pair(scores[i], i));
        }
    }
    std::stable_sort(score_index_vec.begin(), score_index_vec.end(),
                     SortScorePairDescend<int>);
    if (top_k > 0 && top_k < (int)score_index_vec.size())
    {
        score_index_vec.resize(top_k);
    }
}

template <typename BoxType>
inline void NMSFast_(
        const std::vector<BoxType>& bboxes,
        const std::vector<float>& scores, 
        const float score_threshold,
        const float nms_threshold, 
        const float eta, 
        const int top_k,
        std::vector<int>& indices, 
        float (*computeOverlap)(const BoxType&, const BoxType&) )
{
    CV_Assert(bboxes.size() == scores.size());
    std::vector<std::pair<float, int> > score_index_vec;
    GetMaxScoreIndex(scores, score_threshold, top_k, score_index_vec);

    // Do nms.
    float adaptive_threshold = nms_threshold;
    indices.clear();
    for (size_t i = 0; i < score_index_vec.size(); ++i) {
        const int idx = score_index_vec[i].second;
        bool keep = true;
        for (int k = 0; k < (int)indices.size() && keep; ++k) {
            const int kept_idx = indices[k];
            float overlap = computeOverlap(bboxes[idx], bboxes[kept_idx]);
            keep = overlap <= adaptive_threshold;
        }
        if (keep)
            indices.push_back(idx);
        if (keep && eta < 1 && adaptive_threshold > 0.5) {
          adaptive_threshold *= eta;
        }
    }
}


// copied from opencv 3.4, not exist in 3.0
template<typename _Tp> 
static inline
double jaccardDistance__(
        const Rect_<_Tp>& a,
        const Rect_<_Tp>& b)
{
    _Tp Aa = a.area();
    _Tp Ab = b.area();

    if ((Aa + Ab) <= std::numeric_limits<_Tp>::epsilon())
    {
        // jaccard_index = 1 -> distance = 0
        return 0.0;
    }

    double Aab = (a & b).area();
    // distance = 1 - jaccard_index
    return 1.0 - Aab / (Aa + Ab - Aab);
}

template <typename T>
static inline 
float rectOverlap(
    const T& a, 
    const T& b )
{
    return 1.f - static_cast<float>(jaccardDistance__(a, b));
}


float RotatedRectOverlap(
    const RotatedRect& a,
    const RotatedRect& b)
{
    vector<Point2f> vA;
    vector<Point2f> vB;
    vector<Point2f> inSec;

    Point2f pts[4];

    a.points(pts);

    vA.push_back(pts[0]);
    vA.push_back(pts[1]);
    vA.push_back(pts[2]);
    vA.push_back(pts[3]);

    b.points(pts);

    vB.push_back(pts[0]);
    vB.push_back(pts[1]);
    vB.push_back(pts[2]);
    vB.push_back(pts[3]);
    
    int rst = cv::rotatedRectangleIntersection(a, b, inSec);

    if (inSec.size() == 0)
    {
        return 0.0f;
    }
    else
    {
        float areaA = cv::contourArea(vA);
        float areaB = cv::contourArea(vB);
        float areaInsec = cv::contourArea(inSec);

        return areaInsec / (areaA + areaB - areaInsec);
    }
    

}

void NMSBoxes(
    const std::vector<Rect>& bboxes, 
    const std::vector<float>& scores,
    const float score_threshold, 
    const float nms_threshold,
    std::vector<int>& indices, 
    const float eta=1, 
    const int top_k=0 )
{
    NMSFast_(bboxes, scores, score_threshold, nms_threshold, eta, top_k, indices, rectOverlap);
}

void NMSBoxes(
    const std::vector<RotatedRect>& bboxes,
    const std::vector<float>& scores,
    const float score_threshold,
    const float nms_threshold,
    std::vector<int>& indices,
    const float eta = 1,
    const int top_k = 0)
{
    NMSFast_(bboxes, scores, score_threshold, nms_threshold, eta, top_k, indices, RotatedRectOverlap);
}

}


//static cv::Point2f rotate2d(const cv::Point2f inPoint, const double angRad)
//{
//    cv::Point2f outPoint;
//    //CW rotation
//    outPoint.x = std::cos(angRad)*inPoint.x - std::sin(angRad)*inPoint.y;
//    outPoint.y = std::sin(angRad)*inPoint.x + std::cos(angRad)*inPoint.y;
//    return outPoint;
//}

//static cv::Point2f rotatePoint(const cv::Point2f inPoint, const cv::Point2f center, const double angRad)
//{
//    return rotate2d(inPoint - center, angRad) + center;
//}

// ****************************************************************
//*****************************************************************

// Root directory of the demo data: sample images plus the template/info YAML
// files written in "train" mode and read back in "test" mode.
// Previously used location: "C:/Sean/Project/shape_based_matching_win/data/"
static const std::string prefix{"D:\\Code\\shape_based_match_win\\data\\"};


void noise_test(string mode = "test")
{
    int num_feature = 128;

    line2Dup::Detector detector(num_feature, {4, 8});

    if(mode == "train"){
        Mat img = imread(prefix+"case2/train.png"); // case1/train.png
        assert(!img.empty() && "check your img path");
        Mat mask = Mat(img.size(), CV_8UC1, {255});

        shape_based_matching::shapeInfo_producer shapes(img, mask);

        shapes.angle_range = {0, 360};
        shapes.angle_step = 5;
        shapes.scale_range = { 0.8f, 1.2f };
        shapes.scale_step = 0.2f;
        shapes.produce_infos();

        std::vector<shape_based_matching::shapeInfo_producer::Info> infos_have_templ;
        
        string class_id = "test";

        for(auto& info: shapes.infos)
        {
            imshow("train", shapes.src_of(info));
            waitKey(1);

            std::cout << "\ninfo.angle: " << info.angle << std::endl;
            int templ_id = detector.addTemplate(shapes.src_of(info), class_id, shapes.mask_of(info));            
            std::cout << "templ_id: " << templ_id << std::endl;
            if(templ_id != -1)
            {
                infos_have_templ.push_back(info);
            }
        }
        detector.writeClasses(prefix+"case2/%s_templ.yaml");
        shapes.save_infos(infos_have_templ, prefix + "case2/test_info.yaml");
        std::cout << "train end" << std::endl << std::endl;
    }
    else if(mode=="test")
    {
        std::vector<std::string> ids;
        ids.push_back("test");
        detector.readClasses(ids, prefix+"case2/%s_templ.yaml");

        std::vector<shape_based_matching::shapeInfo_producer::Info> infos;
        infos = shape_based_matching::shapeInfo_producer::load_infos(prefix + "case2/test_info.yaml");

        Mat test_img = imread(prefix+"case2/test1.png");
        assert(!test_img.empty() && "check your img path");

        // cvtColor(test_img, test_img, CV_BGR2GRAY);

        int stride = 16;
        int n = test_img.rows/stride;
        int m = test_img.cols/stride;
        Rect roi(0, 0, stride*m , stride*n);

        test_img = test_img(roi).clone();

        Timer timer;
        auto matches = detector.match(test_img, 90, ids);
        timer.out();

        std::cout << "matches.size(): " << matches.size() << std::endl;
        size_t top5 = 50;
        if(top5>matches.size()) top5=matches.size();

        vector<Rect> boxes;
        vector<cv::RotatedRect>  Rboxes;
        vector<int> idxsR;
        vector<float> scores;
        vector<int> idxs;

        for(auto match: matches)
        {
            Rect box;
            cv::RotatedRect  Rbox;
            box.x = match.x;
            box.y = match.y;
            
            auto templ = detector.getTemplates("test",
                                               match.template_id);

            box.width = templ[0].width;
            box.height = templ[0].height;

            Rbox.size.width = templ[0].width;
            Rbox.size.height = templ[0].height;
            Rbox.center.x = match.x + templ[0].width/2;
            Rbox.center.y = match.y + templ[0].height/2;
            Rbox.angle = 360 - infos[match.template_id].angle;

            Rboxes.push_back(Rbox);
            boxes.push_back(box);
            scores.push_back(match.similarity);
        }

        cv_dnn::NMSBoxes(boxes, scores, 0, 0.5f, idxs);
        cv_dnn::NMSBoxes(Rboxes, scores, 0, 0.5f, idxsR);

        for(auto idx: idxs){
            auto match = matches[idx];
            auto templ = detector.getTemplates("test",
                                               match.template_id);

            int x =  templ[0].width + match.x;
            int y = templ[0].height + match.y;
            int r = templ[0].width/2;

            float angle = infos[match.template_id].angle;
            float scale = infos[match.template_id].scale;

            cv::Vec3b randColor;
            randColor[0] = rand()%155 + 100;
            randColor[1] = rand()%155 + 100;
            randColor[2] = rand()%155 + 100;

            for(int i=0; i<templ[0].features.size(); i++){
                auto feat = templ[0].features[i];
                cv::circle(test_img, {feat.x+match.x, feat.y+match.y}, 2, randColor, -1);
            }

            cv::putText(test_img, to_string(int(round(match.similarity))),
                        Point(match.x+r-10, match.y-3), FONT_HERSHEY_PLAIN, 2, randColor);
            
            cv::rectangle(test_img, {match.x, match.y}, {x, y}, randColor, 2);            

            std::cout << "\nmatch.template_id: " << match.template_id << std::endl;
            std::cout << "match.similarity: " << match.similarity << std::endl;
        }


        for (auto idxR : idxsR) 
        {
            auto rbox = Rboxes[idxR];
            Point2f pts[4];
            rbox.points(pts);
            Point pt1, pt2;

            for (int i = 0; i < 3; i++)
            {
                pt1.x = (int)pts[i].x;     pt1.y = (int)pts[i].y;
                pt2.x = (int)pts[i + 1].x; pt2.y = (int)pts[i+1].y;
                cv::line(test_img, pt1, pt2, Scalar(0, 255, 0), 2);
            }
            pt1.x = (int)pts[0].x;     pt1.y = (int)pts[0].y;
            pt2.x = (int)pts[3].x;     pt2.y = (int)pts[3].y;
            cv::line(test_img, pt1, pt2, Scalar(0, 255, 0), 2);
        }

        cv::namedWindow("img", 0);
        imshow("img", test_img);
        waitKey(0);

        std::cout << "test end" << std::endl << std::endl;
    }
}

// Prints the SIMD configuration MIPP was compiled for: instruction set,
// register size, lane count, and 64-bit / byte-word support. It also reports
// when the has_max_int8_t or has_shuff_int8_t macros are not defined.
void MIPP_test(){
    std::cout << "MIPP tests" << std::endl;
    std::cout << "----------" << std::endl << std::endl;

    std::cout << "Instr. type:       " << mipp::InstructionType                  << std::endl;
    std::cout << "Instr. full type:  " << mipp::InstructionFullType              << std::endl;
    std::cout << "Instr. version:    " << mipp::InstructionVersion               << std::endl;
    std::cout << "Instr. size:       " << mipp::RegisterSizeBit       << " bits" << std::endl;
    std::cout << "Instr. lanes:      " << mipp::Lanes                            << std::endl;
    std::cout << "64-bit support:    " << (mipp::Support64Bit    ? "yes" : "no") << std::endl;
    std::cout << "Byte/word support: " << (mipp::SupportByteWord ? "yes" : "no") << std::endl;

#ifndef has_max_int8_t
    std::cout << "in this SIMD, int8 max is not implemented by MIPP" << std::endl;
#endif

#ifndef has_shuff_int8_t
    std::cout << "in this SIMD, int8 shuff is not implemented by MIPP" << std::endl;
#endif

    std::cout << "----------" << std::endl << std::endl;
}

// Entry point. It prints the MIPP SIMD configuration, then runs noise_test
// first in "train" mode (build and save templates) and then in "test" mode
// (match and display).
//
// Other experiments, disabled here; angle_test and scale_test are not
// defined in this file:
//   angle_test("train", true);  angle_test("test", true);
//   scale_test("train");        scale_test("test");
int main()
{
    MIPP_test();

    const char* const modes[] = {"train", "test"};
    for (const char* mode : modes)
    {
        noise_test(mode);
    }

    return 0;
}
